Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add optional column alignment in write #117

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions f90nml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def reads(nml_string):
return parser.reads(nml_string)


def write(nml, nml_path, force=False, sort=False):
def write(nml, nml_path, force=False, sort=False, colwidth=0):
"""Save a namelist to disk using either a file object or its file path.

File object usage:
Expand Down Expand Up @@ -79,7 +79,7 @@ def write(nml, nml_path, force=False, sort=False):
else:
nml_in = nml

nml_in.write(nml_path, force=force, sort=sort)
nml_in.write(nml_path, force=force, sort=sort, colwidth=colwidth)


def patch(nml_path, nml_patch, out_path=None):
Expand Down
40 changes: 24 additions & 16 deletions f90nml/namelist.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def default_start_index(self, value):
raise TypeError('default_start_index must be an integer.')
self._default_start_index = value

def write(self, nml_path, force=False, sort=False):
def write(self, nml_path, force=False, sort=False, colwidth=0):
"""Write Namelist to a Fortran 90 namelist file.

>>> nml = f90nml.read('input.nml')
Expand All @@ -402,7 +402,7 @@ def write(self, nml_path, force=False, sort=False):

nml_file = nml_path if nml_is_file else open(nml_path, 'w')
try:
self._writestream(nml_file, sort)
self._writestream(nml_file, sort=sort, colwidth=colwidth)
finally:
if not nml_is_file:
nml_file.close()
Expand Down Expand Up @@ -430,7 +430,7 @@ def groups(self):
for inner_key, inner_value in value.items():
yield (key, inner_key), inner_value

def _writestream(self, nml_file, sort=False):
def _writestream(self, nml_file, sort=False, colwidth=0):
"""Output Namelist to a streamable file object."""
# Reset newline flag
self._newline = False
Expand All @@ -444,11 +444,14 @@ def _writestream(self, nml_file, sort=False):
# Check for repeated namelist records (saved as lists)
if isinstance(grp_vars, list):
for g_vars in grp_vars:
self._write_nmlgrp(grp_name, g_vars, nml_file, sort)
self._write_nmlgrp(grp_name, g_vars, nml_file, sort=sort,
colwidth=colwidth)
else:
self._write_nmlgrp(grp_name, grp_vars, nml_file, sort)
self._write_nmlgrp(grp_name, grp_vars, nml_file, sort=sort,
colwidth=colwidth)

def _write_nmlgrp(self, grp_name, grp_vars, nml_file, sort=False):
def _write_nmlgrp(self, grp_name, grp_vars, nml_file, sort=False,
colwidth=0):
"""Write namelist group to target file."""
if self._newline:
print(file=nml_file)
Expand All @@ -466,13 +469,15 @@ def _write_nmlgrp(self, grp_name, grp_vars, nml_file, sort=False):

v_start = grp_vars.start_index.get(v_name, None)

for v_str in self._var_strings(v_name, v_val, v_start=v_start):
for v_str in self._var_strings(v_name, v_val, v_start=v_start,
colwidth=colwidth):
nml_line = self.indent + '{0}'.format(v_str)
print(nml_line, file=nml_file)

print('/', file=nml_file)

def _var_strings(self, v_name, v_values, v_idx=None, v_start=None):
def _var_strings(self, v_name, v_values, v_idx=None, v_start=None,
colwidth=0):
"""Convert namelist variable to list of fixed-width strings."""
if self.uppercase:
v_name = v_name.upper()
Expand Down Expand Up @@ -504,7 +509,7 @@ def _var_strings(self, v_name, v_values, v_idx=None, v_start=None):
for idx, val in enumerate(v_values, start=i_s):
v_idx_new = v_idx + [idx]
v_strs = self._var_strings(v_name, val, v_idx=v_idx_new,
v_start=v_start)
v_start=v_start, colwidth=colwidth)
var_strs.extend(v_strs)

# Parse derived type contents
Expand All @@ -515,7 +520,8 @@ def _var_strings(self, v_name, v_values, v_idx=None, v_start=None):
v_start_new = v_values.start_index.get(f_name, None)

v_strs = self._var_strings(v_title, f_vals,
v_start=v_start_new)
v_start=v_start_new,
colwidth=colwidth)
var_strs.extend(v_strs)

# Parse an array of derived types
Expand All @@ -533,7 +539,7 @@ def _var_strings(self, v_name, v_values, v_idx=None, v_start=None):

v_title = v_name + '({0})'.format(idx)

v_strs = self._var_strings(v_title, val)
v_strs = self._var_strings(v_title, val, colwidth=colwidth)
var_strs.extend(v_strs)

else:
Expand Down Expand Up @@ -586,7 +592,7 @@ def _var_strings(self, v_name, v_values, v_idx=None, v_start=None):
val_line = ''
for v_val in v_values:

v_header = v_name + v_idx_repr + ' = '
v_header = (v_name + v_idx_repr).ljust(colwidth) + ' = '
# Increase column width if the header exceeds this value
if len(self.indent + v_header) >= self.column_width:
column_width = len(self.indent + v_header) + 1
Expand Down Expand Up @@ -614,12 +620,14 @@ def _var_strings(self, v_name, v_values, v_idx=None, v_start=None):

# Complete the set of values
if val_strs:
var_strs.append('{0}{1} = {2}'
''.format(v_name, v_idx_repr,
val_strs[0]).strip())
var_strs.append('{0}{1}'.format(v_name, v_idx_repr).strip()
.ljust(colwidth)
+ ' = '
+ '{}'.format(val_strs[0]).strip())

for v_str in val_strs[1:]:
var_strs.append(' ' * len(v_header) + v_str)
var_strs.append((' ' * len(v_header)).ljust(colwidth+3)
+ v_str)

return var_strs

Expand Down