From 7042270d664201badc7ca0058e229d1eda6c2324 Mon Sep 17 00:00:00 2001 From: "T.Tian" Date: Tue, 30 Jan 2024 00:37:45 +0800 Subject: [PATCH] first working example. need heavy reformatting --- sparc/calculator.py | 89 ++++++++++++++++++++++++++++++++++++++------- sparc/socketio.py | 33 ++++++++++++----- 2 files changed, 100 insertions(+), 22 deletions(-) diff --git a/sparc/calculator.py b/sparc/calculator.py index 82c66ab..5e41f81 100644 --- a/sparc/calculator.py +++ b/sparc/calculator.py @@ -238,25 +238,33 @@ def _dump_system_state(self): return system_state def ensure_socket(self): + # TODO: more ensure directory to other place? if not self.directory.is_dir(): os.makedirs(self.directory, exist_ok=True) if not self.use_socket: return if self.in_socket is None: - # self.in_socket is actually a SocketServer - socket_name = generate_random_socket_name() - print(f"Creating a socket server with name {socket_name}") - self.in_socket = SPARCSocketServer( - unixsocket=socket_name, - # TODO: make the log fd persistent - log=self.openfile(self._indir(ext=".log", label="socket"), mode="w"), - parent=self, - ) + if self.socket_mode == "server": + # TODO: Exception for wrong port + self.in_socket = SPARCSocketServer( + port=self.socket_params["port"], + log=self.openfile(self._indir(ext=".log", label="socket"), mode="w"), + parent=self, + ) + else: + socket_name = generate_random_socket_name() + print(f"Creating a socket server with name {socket_name}") + self.in_socket = SPARCSocketServer( + unixsocket=socket_name, + # TODO: make the log fd persistent + log=self.openfile(self._indir(ext=".log", label="socket"), mode="w"), + parent=self, + ) # TODO: add the outbound socket client # TODO: we may need to check an actual socket server at host:port?! # At this stage, we will need to wait the actual client to join if self.out_socket is None: - if self.socket_params["port"] > 0: + if self.socket_mode == "client": self.out_socket = SPARCSocketClient( host=self.socket_params["host"], port=self.socket_params["port"], @@ -283,6 +291,26 @@ def __exit__(self, type, value, traceback): def use_socket(self): return self.socket_params["use_socket"] + @property + def socket_mode(self): + """The mode of the socket calculator: + + disabled: pure SPARC file IO interface + local: Serves as a local SPARC calculator with socket support + client: Relay SPARC calculation + server: Remote server + """ + if self.use_socket: + if self.socket_params["port"] > 0: + if self.socket_params["server_only"]: + return "server" + else: + return "client" + else: + return "local" + else: + return "disabled" + def _indir(self, ext, label=None, occur=0, d_format="{:02d}"): return self.sparc_bundle._indir( ext=ext, label=label, occur=occur, d_format=d_format @@ -483,11 +511,15 @@ def calculate(self, atoms=None, properties=["energy"], system_changes=all_change if param_changed: system_changes.append("parameters") - if self.use_socket: + if self.socket_mode in ("local", "client"): self._calculate_with_socket( atoms=atoms, properties=properties, system_changes=system_changes ) return + + if self.socket_mode == "server": + self._calculate_as_server(atoms=atoms, properties=properties, system_changes=system_changes) + return self.write_input(self.atoms, properties, system_changes) self.execute() self.read_results() @@ -506,6 +538,24 @@ def calculate(self, atoms=None, properties=["energy"], system_changes=all_change self.atoms.get_initial_magnetic_moments() ) + def _calculate_as_server(self, atoms=None, properties=["energy"], + system_changes=all_changes): + """Use the server component to send instructions to socket + """ + ret, raw_results = self.in_socket.calculate_new_protocol( + atoms=atoms, params=self.parameters + ) + self.raw_results = raw_results + if "stress" not in self.results: + virial_from_socket = ret.get("virial", np.zeros(6)) + stress_from_socket = -full_3x3_to_voigt_6_stress(virial_from_socket) / atoms.get_volume() + self.results["stress"] = stress_from_socket + # Energy and forces returned in this case do not need + # resorting, since they are already in the same format + self.results["energy"] = ret["energy"] + self.results["forces"] = ret["forces"] + return + def _calculate_with_socket( self, atoms=None, properties=["energy"], system_changes=all_changes ): @@ -557,7 +607,8 @@ def _calculate_with_socket( universal_newlines=True, bufsize=0, ) - ret = self.in_socket.calculate(atoms[self.sort]) + # in_socket is a server + ret = self.in_socket.calculate_origin_protocol(atoms[self.sort]) # The results are parsed from file outputs (.static + .out) # Except for stress, they should be exactly the same as socket returned results self.read_results() # @@ -573,7 +624,6 @@ def _calculate_with_socket( stress_from_socket = -full_3x3_to_voigt_6_stress(virial_from_socket) / atoms.get_volume() self.results["stress"] = stress_from_socket self.system_state = self._dump_system_state() - return def get_stress(self, atoms=None): @@ -723,6 +773,9 @@ def close(self): if self.in_socket is not None: self.in_socket.close() + if self.out_socket is not None: + self.out_socket.close() + # import pdb; pdb.set_trace() # In most cases if in_socket is closed, the SPARC process should also exit if self.process: @@ -736,6 +789,7 @@ def close(self): # TODO: check if in_socket should be merged self.in_socket = None + self.out_socket = None self._reset_process() def _send_mpi_signal(self, sig): @@ -870,6 +924,15 @@ def detect_sparc_version(self): ) return version + + def run_client(self, atoms=None, use_stress=False): + """Main method to start the client code + """ + if not self.socket_mode == "client": + raise RuntimeError("Cannot use SPARC.run_client if the calculator is not configured in client mode!") + + self.out_socket.run(atoms, use_stress) + def detect_socket_compatibility(self): """Test if the sparc binary supports socket mode""" try: diff --git a/sparc/socketio.py b/sparc/socketio.py index ea547f7..1ea93f2 100644 --- a/sparc/socketio.py +++ b/sparc/socketio.py @@ -90,10 +90,6 @@ def send_param(self, name, value): """ self.log(f"Setup param {name}, {value}") msg = self.status() - # TODO: see how NEEDINIT works - # if msg == 'NEEDINIT': - # self.sendinit() - # msg = self.status() assert msg == "READY", msg # Send message self.sendmsg("SETPARAM") @@ -124,6 +120,8 @@ def recvinit(self): return super().recvinit() def calculate_new_protocol(self, atoms, params): + atoms = atoms.copy() + atoms.calc = None self.log(' calculate with new protocol') msg = self.status() # We don't know how NEEDINIT is supposed to work, but some codes @@ -213,10 +211,28 @@ def send_atoms_and_params(self, atoms, params): self.protocol.send_object(pair) return - # def calculate(self, positions, cell): - # """Fallback protocol to adapt the original i-PI protocol - # """ - # return super().calculate(positions, cell) + def calculate_origin_protocol(self, atoms): + """Send geometry to client and return calculated things as dict. + + This will block until client has established connection, then + wait for the client to finish the calculation.""" + assert not self._closed + + # If we have not established connection yet, we must block + # until the client catches up: + if self.protocol is None: + self._accept() + return self.protocol.calculate(atoms.positions, atoms.cell) + + def calculate_new_protocol(self, atoms, params={}): + assert not self._closed + + # If we have not established connection yet, we must block + # until the client catches up: + if self.protocol is None: + self._accept() + return self.protocol.calculate_new_protocol(atoms, params) + class SPARCSocketClient(SocketClient): @@ -319,7 +335,6 @@ def irun(self, atoms, use_stress=True): print(recv_atoms, params) if params != {}: self.parent_calc.set(**params) - # self.parent_calc.atoms = recv_atoms # TODO: should we update the atoms directly or keep copy? atoms = recv_atoms atoms.calc = self.parent_calc