Skip to content

Commit

Permalink
first working example. need heavy reformatting
Browse files Browse the repository at this point in the history
  • Loading branch information
alchem0x2A committed Jan 29, 2024
1 parent f81e9cd commit 7042270
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 22 deletions.
89 changes: 76 additions & 13 deletions sparc/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,25 +238,33 @@ def _dump_system_state(self):
return system_state

def ensure_socket(self):
# TODO: more ensure directory to other place?
if not self.directory.is_dir():
os.makedirs(self.directory, exist_ok=True)
if not self.use_socket:
return
if self.in_socket is None:
# self.in_socket is actually a SocketServer
socket_name = generate_random_socket_name()
print(f"Creating a socket server with name {socket_name}")
self.in_socket = SPARCSocketServer(
unixsocket=socket_name,
# TODO: make the log fd persistent
log=self.openfile(self._indir(ext=".log", label="socket"), mode="w"),
parent=self,
)
if self.socket_mode == "server":
# TODO: Exception for wrong port
self.in_socket = SPARCSocketServer(
port=self.socket_params["port"],
log=self.openfile(self._indir(ext=".log", label="socket"), mode="w"),
parent=self,
)
else:
socket_name = generate_random_socket_name()
print(f"Creating a socket server with name {socket_name}")
self.in_socket = SPARCSocketServer(
unixsocket=socket_name,
# TODO: make the log fd persistent
log=self.openfile(self._indir(ext=".log", label="socket"), mode="w"),
parent=self,
)
# TODO: add the outbound socket client
# TODO: we may need to check an actual socket server at host:port?!
# At this stage, we will need to wait the actual client to join
if self.out_socket is None:
if self.socket_params["port"] > 0:
if self.socket_mode == "client":
self.out_socket = SPARCSocketClient(
host=self.socket_params["host"],
port=self.socket_params["port"],
Expand All @@ -283,6 +291,26 @@ def __exit__(self, type, value, traceback):
def use_socket(self):
return self.socket_params["use_socket"]

@property
def socket_mode(self):
"""The mode of the socket calculator:
disabled: pure SPARC file IO interface
local: Serves as a local SPARC calculator with socket support
client: Relay SPARC calculation
server: Remote server
"""
if self.use_socket:
if self.socket_params["port"] > 0:
if self.socket_params["server_only"]:
return "server"
else:
return "client"
else:
return "local"
else:
return "disabled"

def _indir(self, ext, label=None, occur=0, d_format="{:02d}"):
return self.sparc_bundle._indir(
ext=ext, label=label, occur=occur, d_format=d_format
Expand Down Expand Up @@ -483,11 +511,15 @@ def calculate(self, atoms=None, properties=["energy"], system_changes=all_change
if param_changed:
system_changes.append("parameters")

if self.use_socket:
if self.socket_mode in ("local", "client"):
self._calculate_with_socket(
atoms=atoms, properties=properties, system_changes=system_changes
)
return

if self.socket_mode == "server":
self._calculate_as_server(atoms=atoms, properties=properties, system_changes=system_changes)
return
self.write_input(self.atoms, properties, system_changes)
self.execute()
self.read_results()
Expand All @@ -506,6 +538,24 @@ def calculate(self, atoms=None, properties=["energy"], system_changes=all_change
self.atoms.get_initial_magnetic_moments()
)

def _calculate_as_server(self, atoms=None, properties=["energy"],
system_changes=all_changes):
"""Use the server component to send instructions to socket
"""
ret, raw_results = self.in_socket.calculate_new_protocol(
atoms=atoms, params=self.parameters
)
self.raw_results = raw_results
if "stress" not in self.results:
virial_from_socket = ret.get("virial", np.zeros(6))
stress_from_socket = -full_3x3_to_voigt_6_stress(virial_from_socket) / atoms.get_volume()
self.results["stress"] = stress_from_socket
# Energy and forces returned in this case do not need
# resorting, since they are already in the same format
self.results["energy"] = ret["energy"]
self.results["forces"] = ret["forces"]
return

def _calculate_with_socket(
self, atoms=None, properties=["energy"], system_changes=all_changes
):
Expand Down Expand Up @@ -557,7 +607,8 @@ def _calculate_with_socket(
universal_newlines=True,
bufsize=0,
)
ret = self.in_socket.calculate(atoms[self.sort])
# in_socket is a server
ret = self.in_socket.calculate_origin_protocol(atoms[self.sort])
# The results are parsed from file outputs (.static + .out)
# Except for stress, they should be exactly the same as socket returned results
self.read_results() #
Expand All @@ -573,7 +624,6 @@ def _calculate_with_socket(
stress_from_socket = -full_3x3_to_voigt_6_stress(virial_from_socket) / atoms.get_volume()
self.results["stress"] = stress_from_socket
self.system_state = self._dump_system_state()

return

def get_stress(self, atoms=None):
Expand Down Expand Up @@ -723,6 +773,9 @@ def close(self):
if self.in_socket is not None:
self.in_socket.close()

if self.out_socket is not None:
self.out_socket.close()

# import pdb; pdb.set_trace()
# In most cases if in_socket is closed, the SPARC process should also exit
if self.process:
Expand All @@ -736,6 +789,7 @@ def close(self):

# TODO: check if in_socket should be merged
self.in_socket = None
self.out_socket = None
self._reset_process()

def _send_mpi_signal(self, sig):
Expand Down Expand Up @@ -870,6 +924,15 @@ def detect_sparc_version(self):
)
return version


def run_client(self, atoms=None, use_stress=False):
"""Main method to start the client code
"""
if not self.socket_mode == "client":
raise RuntimeError("Cannot use SPARC.run_client if the calculator is not configured in client mode!")

self.out_socket.run(atoms, use_stress)

def detect_socket_compatibility(self):
"""Test if the sparc binary supports socket mode"""
try:
Expand Down
33 changes: 24 additions & 9 deletions sparc/socketio.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,6 @@ def send_param(self, name, value):
"""
self.log(f"Setup param {name}, {value}")
msg = self.status()
# TODO: see how NEEDINIT works
# if msg == 'NEEDINIT':
# self.sendinit()
# msg = self.status()
assert msg == "READY", msg
# Send message
self.sendmsg("SETPARAM")
Expand Down Expand Up @@ -124,6 +120,8 @@ def recvinit(self):
return super().recvinit()

def calculate_new_protocol(self, atoms, params):
atoms = atoms.copy()
atoms.calc = None
self.log(' calculate with new protocol')
msg = self.status()
# We don't know how NEEDINIT is supposed to work, but some codes
Expand Down Expand Up @@ -213,10 +211,28 @@ def send_atoms_and_params(self, atoms, params):
self.protocol.send_object(pair)
return

# def calculate(self, positions, cell):
# """Fallback protocol to adapt the original i-PI protocol
# """
# return super().calculate(positions, cell)
def calculate_origin_protocol(self, atoms):
"""Send geometry to client and return calculated things as dict.
This will block until client has established connection, then
wait for the client to finish the calculation."""
assert not self._closed

# If we have not established connection yet, we must block
# until the client catches up:
if self.protocol is None:
self._accept()
return self.protocol.calculate(atoms.positions, atoms.cell)

def calculate_new_protocol(self, atoms, params={}):
assert not self._closed

# If we have not established connection yet, we must block
# until the client catches up:
if self.protocol is None:
self._accept()
return self.protocol.calculate_new_protocol(atoms, params)



class SPARCSocketClient(SocketClient):
Expand Down Expand Up @@ -319,7 +335,6 @@ def irun(self, atoms, use_stress=True):
print(recv_atoms, params)
if params != {}:
self.parent_calc.set(**params)
# self.parent_calc.atoms = recv_atoms
# TODO: should we update the atoms directly or keep copy?
atoms = recv_atoms
atoms.calc = self.parent_calc
Expand Down

0 comments on commit 7042270

Please sign in to comment.