diff --git a/Unpaint/ImportHuggingFaceModelDialog.xaml b/Unpaint/ImportHuggingFaceModelDialog.xaml index b80f656..78a955a 100644 --- a/Unpaint/ImportHuggingFaceModelDialog.xaml +++ b/Unpaint/ImportHuggingFaceModelDialog.xaml @@ -17,7 +17,7 @@ - You can import Stable Diffusion models from HuggingFace.co below. Please note that only ONNX models are supported by Unpaint. + You can import Stable Diffusion models from HuggingFace.co below. Please note that only ONNX models are supported by Unpaint. diff --git a/Unpaint/InferenceOptionsView.xaml b/Unpaint/InferenceOptionsView.xaml index bafb5ef..f5951dd 100644 --- a/Unpaint/InferenceOptionsView.xaml +++ b/Unpaint/InferenceOptionsView.xaml @@ -29,6 +29,7 @@ + @@ -59,10 +60,15 @@ Click="{x:Bind ViewModel.ManageModels}"/> - - + + + + + @@ -73,29 +79,29 @@ - - - - + - - + - - + + diff --git a/Unpaint/InferenceOptionsViewModel.cpp b/Unpaint/InferenceOptionsViewModel.cpp index 1af1888..e59fbee 100644 --- a/Unpaint/InferenceOptionsViewModel.cpp +++ b/Unpaint/InferenceOptionsViewModel.cpp @@ -3,6 +3,7 @@ #include "InferenceOptionsViewModel.g.cpp" using namespace Axodox::Infrastructure; +using namespace Axodox::MachineLearning; using namespace winrt::Windows::Graphics; using namespace winrt::Windows::UI::Xaml::Data; @@ -14,6 +15,7 @@ namespace winrt::Unpaint::implementation _modelRepository(dependencies.resolve()), _deviceInformation(dependencies.resolve()), _models(single_threaded_observable_vector()), + _schedulers(single_threaded_observable_vector()), _resolutions(single_threaded_observable_vector()), _selectedResolutionIndex(1), _selectedModelIndex(0), @@ -22,6 +24,10 @@ namespace winrt::Unpaint::implementation //Initialize models OnModelChanged(); + //Initialize schedulers + _schedulers.Append(L"Euler A"); + _schedulers.Append(L"DPM++ 2M Karras"); + //Initialize resolutions _resolutions.Append(SizeInt32{ 1024, 1024 }); _resolutions.Append(SizeInt32{ 768, 768 }); @@ -48,6 +54,24 @@ namespace winrt::Unpaint::implementation _propertyChanged(*this, PropertyChangedEventArgs(L"SelectedModelIndex")); } + winrt::Windows::Foundation::Collections::IObservableVector InferenceOptionsViewModel::Schedulers() + { + return _schedulers; + } + + int32_t InferenceOptionsViewModel::SelectedSchedulerIndex() + { + return int32_t(*_unpaintState->Scheduler); + } + + void InferenceOptionsViewModel::SelectedSchedulerIndex(int32_t value) + { + if (value == SelectedSchedulerIndex()) return; + + _unpaintState->Scheduler = StableDiffusionSchedulerKind(value); + _propertyChanged(*this, PropertyChangedEventArgs(L"SelectedSchedulerIndex")); + } + winrt::Windows::Foundation::Collections::IObservableVector InferenceOptionsViewModel::Resolutions() { return _resolutions; diff --git a/Unpaint/InferenceOptionsViewModel.h b/Unpaint/InferenceOptionsViewModel.h index 9379a16..f8296a3 100644 --- a/Unpaint/InferenceOptionsViewModel.h +++ b/Unpaint/InferenceOptionsViewModel.h @@ -14,6 +14,11 @@ namespace winrt::Unpaint::implementation int32_t SelectedModelIndex(); void SelectedModelIndex(int32_t value); + winrt::Windows::Foundation::Collections::IObservableVector Schedulers(); + + int32_t SelectedSchedulerIndex(); + void SelectedSchedulerIndex(int32_t value); + winrt::Windows::Foundation::Collections::IObservableVector Resolutions(); int32_t SelectedResolutionIndex(); @@ -51,6 +56,8 @@ namespace winrt::Unpaint::implementation Windows::Foundation::Collections::IObservableVector _models; int32_t _selectedModelIndex; + Windows::Foundation::Collections::IObservableVector _schedulers; + Windows::Foundation::Collections::IObservableVector _resolutions; int32_t _selectedResolutionIndex; diff --git a/Unpaint/InferenceView.xaml b/Unpaint/InferenceView.xaml index de99be1..200be6e 100644 --- a/Unpaint/InferenceView.xaml +++ b/Unpaint/InferenceView.xaml @@ -333,7 +333,7 @@ - + diff --git a/Unpaint/InferenceViewModel.cpp b/Unpaint/InferenceViewModel.cpp index facdf92..12acca2 100644 --- a/Unpaint/InferenceViewModel.cpp +++ b/Unpaint/InferenceViewModel.cpp @@ -53,7 +53,8 @@ namespace winrt::Unpaint::implementation _featureMask(nullptr), _inputResolution({ 0, 0 }), _isAutoGenerationEnabled(false), - _safetyStrikes(0) + _safetyStrikes(0), + _modelChangedSubscription(_unpaintState->ModelId.ValueChanged(event_handler{ this, &InferenceViewModel::OnModelChanged })) { } ProjectViewModel InferenceViewModel::Project() @@ -74,6 +75,12 @@ namespace winrt::Unpaint::implementation _propertyChanged(*this, PropertyChangedEventArgs(L"SelectedModeIndex")); } + bool InferenceViewModel::IsModeSelectable() + { + auto model = _modelRepository->GetModel(*_unpaintState->ModelId); + return model && !model->IsXL; + } + bool InferenceViewModel::IsSettingsLocked() { return _unpaintState->IsSettingsLocked; @@ -250,6 +257,7 @@ namespace winrt::Unpaint::implementation .Mode = _unpaintState->InferenceMode, .PositivePrompt = _unpaintState->PositivePrompt->empty() ? StableDiffusionInferenceTask::PositivePromptPlaceholder : *_unpaintState->PositivePrompt, .NegativePrompt = _unpaintState->NegativePrompt->empty() ? StableDiffusionInferenceTask::NegativePromptPlaceholder : *_unpaintState->NegativePrompt, + .Scheduler = _unpaintState->Scheduler, .Resolution = { uint32_t(_unpaintState->Resolution->Width), uint32_t(_unpaintState->Resolution->Height) }, .GuidanceStrength = _unpaintState->GuidanceStrength, .DenoisingStrength = _unpaintState->DenoisingStrength, @@ -258,7 +266,7 @@ namespace winrt::Unpaint::implementation .BatchSize = _unpaintState->IsBatchGenerationEnabled ? *_unpaintState->BatchSize : 1, .IsSafeModeEnabled = _unpaintState->IsSafeModeEnabled, .IsSafetyCheckerEnabled = _unpaintState->IsSafetyCheckerEnabled, - .ModelId = _unpaintState->ModelId + .ModelId = _unpaintState->ModelId }; filesystem::path inputPath; @@ -461,4 +469,14 @@ namespace winrt::Unpaint::implementation { _propertyChanged.remove(token); } + + void InferenceViewModel::OnModelChanged(OptionPropertyBase* /*option*/) + { + if (!IsModeSelectable()) + { + SelectedModeIndex(0); + } + + _propertyChanged(*this, PropertyChangedEventArgs(L"IsModeSelectable")); + } } diff --git a/Unpaint/InferenceViewModel.h b/Unpaint/InferenceViewModel.h index 5e36f0d..7234d19 100644 --- a/Unpaint/InferenceViewModel.h +++ b/Unpaint/InferenceViewModel.h @@ -19,6 +19,8 @@ namespace winrt::Unpaint::implementation int32_t SelectedModeIndex(); void SelectedModeIndex(int32_t value); + bool IsModeSelectable(); + bool IsSettingsLocked(); void IsSettingsLocked(bool value); @@ -87,6 +89,10 @@ namespace winrt::Unpaint::implementation bool _isAutoGenerationEnabled; bool _hasSafetyCheckFailed; uint32_t _safetyStrikes; + + Axodox::Infrastructure::event_subscription _modelChangedSubscription; + + void OnModelChanged(OptionPropertyBase* option); }; } diff --git a/Unpaint/ModelRepository.cpp b/Unpaint/ModelRepository.cpp index a782314..5fcce19 100644 --- a/Unpaint/ModelRepository.cpp +++ b/Unpaint/ModelRepository.cpp @@ -122,11 +122,14 @@ namespace winrt::Unpaint auto metadata = try_parse_json(*text); if (!metadata) continue; + auto isXL = filesystem::exists(file.path().parent_path() / "text_encoder_2", ec); + models.emplace(ModelInfo{ *metadata->Id, metadata->Name->empty() ? *metadata->Id : *metadata->Name, *metadata->Website, - *metadata->AccessToken + *metadata->AccessToken, + isXL }); } @@ -192,7 +195,8 @@ namespace winrt::Unpaint return ModelViewModel{ .Id = to_hstring(Id), .Name = to_hstring(Name), - .Uri = to_hstring(Website) + .Uri = to_hstring(Website), + .IsXL = IsXL }; } } \ No newline at end of file diff --git a/Unpaint/ModelRepository.h b/Unpaint/ModelRepository.h index c559eb1..7203b7a 100644 --- a/Unpaint/ModelRepository.h +++ b/Unpaint/ModelRepository.h @@ -23,6 +23,7 @@ namespace winrt::Unpaint std::string Name; std::string Website; std::string AccessToken; + bool IsXL; auto operator<=>(const ModelInfo&) const = default; bool operator<(const ModelInfo&) const = default; diff --git a/Unpaint/SettingsViewModel.cpp b/Unpaint/SettingsViewModel.cpp index a88b379..e13ea67 100644 --- a/Unpaint/SettingsViewModel.cpp +++ b/Unpaint/SettingsViewModel.cpp @@ -30,12 +30,9 @@ namespace winrt::Unpaint::implementation bool SettingsViewModel::AreUnsafeOptionsEnabled() { -#ifdef NDEBUG - //Sorry mates I do not trust you this much... - return false; -#else - return true; -#endif +#pragma warning(suppress: 4996) + auto devMode = getenv("UNPAINT_DEV"); + return devMode && strcmp(devMode, "1") == 0; } bool SettingsViewModel::IsSafeModeEnabled() diff --git a/Unpaint/StableDiffusionModelExecutor.cpp b/Unpaint/StableDiffusionModelExecutor.cpp index 79b7708..ca09a65 100644 --- a/Unpaint/StableDiffusionModelExecutor.cpp +++ b/Unpaint/StableDiffusionModelExecutor.cpp @@ -277,13 +277,13 @@ namespace winrt::Unpaint for (auto weight : encodedPositivePrompt[0].Weights) textEmbbedding.Weights.push_back(weight); ScheduledTensor tensor{ task.SamplingSteps }; - trivial_map, shared_ptr> embeddingBuffer; + trivial_map, shared_ptr> embeddingBuffer; for (auto i = 0u; i < task.SamplingSteps; i++) { auto& concatenatedTensor = embeddingBuffer[{ encodedNegativePrompt[i].Tensor.get(), encodedPositivePrompt[i].Tensor.get() }]; if (!concatenatedTensor) { - concatenatedTensor = make_shared(encodedNegativePrompt[i].Tensor->Concat(*encodedPositivePrompt[i].Tensor)); + concatenatedTensor = make_shared(encodedNegativePrompt[i].Tensor->Concat(*encodedPositivePrompt[i].Tensor)); } tensor[i] = concatenatedTensor; @@ -330,7 +330,8 @@ namespace winrt::Unpaint .TextEmbeddings = inputs.TextEmbeddings, .LatentInput = inputs.InputImage, .MaskInput = inputs.InputMask, - .DenoisingStrength = task.Mode == InferenceMode::Modify ? task.DenoisingStrength : 1.f + .DenoisingStrength = task.Mode == InferenceMode::Modify ? task.DenoisingStrength : 1.f, + .Scheduler = task.Scheduler }; async.update_state("Running denoiser..."); @@ -366,8 +367,11 @@ namespace winrt::Unpaint void StableDiffusionModelExecutor::RunSafetyCheck(std::vector& images, Axodox::Threading::async_operation_source& async) { + static const char* safetyCheckedId = "safety_checker\\model.onnx"; + if (!_modelFiles.contains(safetyCheckedId)) return; + async.update_state(NAN, "Loading safety checker..."); - SafetyChecker safetyChecker{ *_onnxEnvironment, GetModelFile("safety_checker\\model.onnx") }; + SafetyChecker safetyChecker{ *_onnxEnvironment, GetModelFile(safetyCheckedId) }; async.update_state(NAN, "Checking safety..."); for (auto& image : images) diff --git a/Unpaint/StableDiffusionModelExecutor.h b/Unpaint/StableDiffusionModelExecutor.h index 278dad7..1027961 100644 --- a/Unpaint/StableDiffusionModelExecutor.h +++ b/Unpaint/StableDiffusionModelExecutor.h @@ -21,6 +21,7 @@ namespace winrt::Unpaint InferenceMode Mode; std::string PositivePrompt, NegativePrompt; + Axodox::MachineLearning::StableDiffusionSchedulerKind Scheduler; DirectX::XMUINT2 Resolution; float GuidanceStrength; diff --git a/Unpaint/Unpaint.vcxproj b/Unpaint/Unpaint.vcxproj index 7a24d6b..6767fad 100644 --- a/Unpaint/Unpaint.vcxproj +++ b/Unpaint/Unpaint.vcxproj @@ -1,8 +1,8 @@ + + - - Release @@ -425,12 +425,12 @@ - - - + + + @@ -439,14 +439,14 @@ - - - - - + + + + + \ No newline at end of file diff --git a/Unpaint/UnpaintState.cpp b/Unpaint/UnpaintState.cpp index 67324e5..9afc0ba 100644 --- a/Unpaint/UnpaintState.cpp +++ b/Unpaint/UnpaintState.cpp @@ -2,6 +2,7 @@ #include "UnpaintState.h" using namespace Axodox::Infrastructure; +using namespace Axodox::MachineLearning; using namespace winrt::Windows::Graphics; namespace winrt::Unpaint @@ -16,6 +17,7 @@ namespace winrt::Unpaint IsFeatureExtractorPinned("Inference.IsFeatureExtractorPinned", false), AdapterIndex("Inference.AdapterIndex", 0), ModelId("Inference.ModelId"), + Scheduler("Inference.Scheduler", StableDiffusionSchedulerKind::EulerAncestral), StateChanged(_events) { auto deviceInformation = dependencies.resolve(); @@ -49,17 +51,13 @@ namespace winrt::Unpaint IsFeatureExtractorPinned.ValueChanged(no_revoke, event_handler{ this, &UnpaintState::OnStateChanged }); AdapterIndex.ValueChanged(no_revoke, event_handler{ this, &UnpaintState::OnStateChanged }); ModelId.ValueChanged(no_revoke, event_handler{ this, &UnpaintState::OnStateChanged }); + Scheduler.ValueChanged(no_revoke, event_handler{ this, &UnpaintState::OnStateChanged }); //Initialize properties if (*Resolution == SizeInt32{0, 0}) { Resolution = deviceInformation->IsDeviceXbox() ? SizeInt32{ 512, 512 } : SizeInt32{ 768, 768 }; } - -#ifdef NDEBUG - IsSafeModeEnabled = true; - IsSafetyCheckerEnabled = true; -#endif } void UnpaintState::OnStateChanged(OptionPropertyBase* property) diff --git a/Unpaint/UnpaintState.h b/Unpaint/UnpaintState.h index 540b80a..2f93808 100644 --- a/Unpaint/UnpaintState.h +++ b/Unpaint/UnpaintState.h @@ -53,6 +53,7 @@ namespace winrt::Unpaint PersistentOptionProperty IsFeatureExtractorPinned; PersistentOptionProperty AdapterIndex; PersistentOptionProperty ModelId; + PersistentOptionProperty Scheduler; #pragma endregion Axodox::Infrastructure::event_publisher StateChanged; diff --git a/Unpaint/ViewModels.idl b/Unpaint/ViewModels.idl index 3e81747..95dccc2 100644 --- a/Unpaint/ViewModels.idl +++ b/Unpaint/ViewModels.idl @@ -27,6 +27,7 @@ namespace Unpaint String Id; String Name; String Uri; + Boolean IsXL; }; [default_interface] @@ -81,6 +82,7 @@ namespace Unpaint ProjectViewModel Project{ get; }; Int32 SelectedModeIndex; + Boolean IsModeSelectable{ get; }; Boolean IsSettingsLocked; Boolean IsJumpingToLatestImage; @@ -163,6 +165,9 @@ namespace Unpaint IObservableVector Models{ get; }; Int32 SelectedModelIndex; + IObservableVector Schedulers{ get; }; + Int32 SelectedSchedulerIndex; + IObservableVector Resolutions{ get; }; Int32 SelectedResolutionIndex; diff --git a/Unpaint/packages.config b/Unpaint/packages.config index 95e21cb..987e919 100644 --- a/Unpaint/packages.config +++ b/Unpaint/packages.config @@ -1,9 +1,9 @@  - + - - + +