99 changes: 99 additions & 0 deletions doc/sphinx/examples/math/run_mapping.py
@@ -0,0 +1,99 @@
"""
Shamrock 3D unit vector generator
=======================================

This example shows how to use the unit vector generator
Comment on lines +2 to +5

high

The docstring appears to be copied from another example and is not relevant to this script. It describes a "3D unit vector generator," whereas this example demonstrates arbitrary function mapping. Updating the docstring will make the example's purpose clear to users.

Shamrock arbitrary function mapping
===================================

This example shows how to map a set of uniformly distributed random points
to a distribution following an arbitrary function.

"""

# %%

import matplotlib.pyplot as plt # plots
import numpy as np # sqrt & arctan2

import shamrock
import sympy as sp
from scipy.special import erfinv, erf
Comment on lines +13 to +15

medium

The shamrock module and the erfinv and erf functions are imported but never used in the script. Removing unused imports keeps the example clean and avoids confusion about its dependencies.

import sympy as sp



#random set of points between 0 and 1
np.random.seed(111)
points = np.random.rand(1000)[:]

range_start = (0,3)
range_end = (0,1) # must be between 0 and 1 because of the normalization

#define the function exp(-x^2) using sympy

high

The comment incorrectly describes the function as exp(-x^2). The actual function being defined is |sin(-x^2)|. This should be corrected to avoid confusion.

# define the function f(x) = |sin(-x^2)| using sympy

x = sp.symbols('x')
f = sp.Abs(sp.sin(-x**2))

# Numerical integration of f over range_start
primitive = []
primitive_x = []
accum = 0
dx = (range_start[1] - range_start[0]) / 100
for x_val in np.linspace(range_start[0], range_start[1], 100):
    primitive.append(accum)
    primitive_x.append(x_val)
    accum += f.subs(x, x_val).evalf() * dx

primitive = np.array(primitive)
primitive_x = np.array(primitive_x)
Comment on lines +29 to +40

high

The numerical integration is implemented as a Python for loop that applies a left-rectangle rule and calls f.subs(...).evalf() symbolically at every point. This is slow and less accurate than the available library routines. Using sympy.lambdify for fast array evaluation and scipy.integrate.cumulative_trapezoid for the integration makes the code faster, more idiomatic, and more accurate. Also, dx is slightly off: np.linspace(a, b, 100) returns 100 points and therefore 99 intervals, so the actual spacing is (b - a) / 99, not (b - a) / 100.

# This import belongs with the other imports at the top of the file
from scipy.integrate import cumulative_trapezoid

# Lambdify the sympy function for faster evaluation with numpy
f_numpy = sp.lambdify(x, f, 'numpy')

# Numerical integration of f over range_start
# Use more points for a better approximation of the integral
num_points = 1001
primitive_x = np.linspace(range_start[0], range_start[1], num_points)
f_vals = f_numpy(primitive_x)
# initial=0 keeps primitive the same length as primitive_x
primitive = cumulative_trapezoid(f_vals, primitive_x, initial=0)
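
As a rough illustration of the two points above (the spacing mismatch and the library-based integration), here is a minimal self-contained sketch. The names a, b, n, grid, true_dx, and loop_dx are introduced only for this illustration and are not part of the PR.

```python
import numpy as np
import sympy as sp
from scipy.integrate import cumulative_trapezoid

# Same function as in the example: f(x) = |sin(-x^2)|
x = sp.symbols("x")
f = sp.Abs(sp.sin(-x**2))
f_numpy = sp.lambdify(x, f, "numpy")

a, b = 0.0, 3.0  # assumed to match range_start in the example

# np.linspace(a, b, n) yields n points, hence n - 1 intervals:
# the true spacing is (b - a) / (n - 1), not (b - a) / n.
n = 100
grid = np.linspace(a, b, n)
true_dx = grid[1] - grid[0]
loop_dx = (b - a) / n  # spacing used by the original loop

# Cumulative trapezoid integral on a finer grid.
# initial=0 makes the result the same length as primitive_x.
primitive_x = np.linspace(a, b, 1001)
primitive = cumulative_trapezoid(f_numpy(primitive_x), primitive_x, initial=0)

print(f"true spacing = {true_dx}, spacing used in the loop = {loop_dx}")
print(f"integral over [{a}, {b}] = {primitive[-1]}")
```

The spacing error is about 1% here, which shifts the unnormalized integral by the same factor. It cancels after normalization, but the left-rectangle error does not.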


# normalize f so that primitive[-1] = 1
norm = primitive[-1]
primitive = primitive / norm
f = f / norm

print(f"primitive = {primitive}")
print(f"primitive_x = {primitive_x}")
Comment on lines +47 to +48

medium

These print statements output large arrays to the console, which is generally not desired in gallery examples. Consider removing them to keep the output clean.


# plot f
plt.figure()
x_plot = np.linspace(range_start[0], range_start[1], 100)
f_plot = [f.subs(x, x_val).evalf() for x_val in x_plot]
plt.plot(x_plot, f_plot)
plt.xlabel("x")
plt.ylabel("f(x)")
plt.title("f(x) = exp(-x^2)")
Comment on lines +53 to +57

high

The plot title is incorrect and evaluation of the function inside a list comprehension is inefficient. The title should reflect the actual function being plotted, |sin(-x^2)|. For performance, it's better to evaluate the function on the whole x_plot array at once, especially after lambdifying the sympy expression.

f_plot = f_numpy(x_plot)
plt.plot(x_plot, f_plot)
plt.xlabel("x")
plt.ylabel("f(x)")
plt.title("f(x) = |sin(-x^2)|")


# plot primitive
plt.figure()
plt.plot(primitive_x, primitive)
plt.xlabel("x")
plt.ylabel("primitive(x)")
plt.title("primitive(x) = integral(f(x))")

# plot finv
plt.figure()
plt.plot(primitive, primitive_x)
plt.xlabel("x")
plt.ylabel("finv(x)")
plt.title("finv(x) = inverse(primitive(x))")

#interpolate primitive using scipy.interpolate.interp1d
from scipy.interpolate import interp1d

medium

Imports should be placed at the top of the file, following PEP 8 conventions. This makes it easier to see all dependencies at a glance. Please move this import to the top with the others.

mapping_interp = interp1d(primitive, primitive_x, kind='linear')

points_mapped = [mapping_interp(point) for point in points]

high

Using a list comprehension here is inefficient. The interp1d object can directly accept a NumPy array, which is much faster. Vectorizing this operation is more idiomatic and performant.

points_mapped = mapping_interp(points)
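
For reference, a minimal sketch of the whole mapping step in vectorized form. To keep it NumPy-only, np.interp is swapped in for interp1d and a hand-written cumulative trapezoid replaces the SciPy call. The f_numpy function below is a hypothetical stand-in for the lambdified expression, and the names grid, cdf, and rng are assumptions for this illustration.

```python
import numpy as np


def f_numpy(x):
    """Hypothetical stand-in for the lambdified f(x) = |sin(-x^2)|."""
    return np.abs(np.sin(-x**2))


range_start = (0.0, 3.0)

# Tabulate the normalized primitive (the CDF) on a fine grid.
grid = np.linspace(range_start[0], range_start[1], 1001)
f_vals = f_numpy(grid)
steps = 0.5 * (f_vals[1:] + f_vals[:-1]) * np.diff(grid)  # trapezoid areas
cdf = np.concatenate(([0.0], np.cumsum(steps)))
cdf /= cdf[-1]  # now cdf[0] == 0 and cdf[-1] == 1

# Uniform samples in [0, 1)
rng = np.random.default_rng(111)
points = rng.random(1000)

# Vectorized inverse lookup: the whole array is mapped in one call,
# with (cdf, grid) swapped to evaluate the inverse of the primitive.
points_mapped = np.interp(points, cdf, grid)

print(points_mapped.min(), points_mapped.max())
```

Because the primitive is non-decreasing, swapping the roles of the x and y arrays gives its inverse directly, with no Python-level loop over the samples.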


print(f"points = {points}")
print(f"points_mapped = {points_mapped}")
Comment on lines +79 to +80

medium

These print statements output large arrays to the console. For a gallery example, this is usually not necessary and clutters the output. It's better to remove them.


plt.figure()
hist_r, bins_r = np.histogram(points, bins=101, density=True, range=range_end)
r = np.linspace(bins_r[0], bins_r[-1], 101)

medium

The variable r is defined here but it is not used in the subsequent plotting code for the first histogram. It should be removed to avoid confusion.


plt.bar(bins_r[:-1], hist_r, np.diff(bins_r), alpha=0.5)
plt.xlabel("$r$")
plt.ylabel("$f(r)$")

plt.figure()
hist_r, bins_r = np.histogram(points_mapped, bins=101, density=True, range=range_start)
r = np.linspace(bins_r[0], bins_r[-1], 101)

plt.bar(bins_r[:-1], hist_r, np.diff(bins_r), alpha=0.5)
plt.plot(r, [f.subs(x, x_val).evalf() for x_val in r])

high

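Evaluating the function in a list comprehension is inefficient. With f lambdified to f_numpy, you can pass the NumPy array r directly for a significant performance improvement. Note that f_numpy is built before f is divided by norm, so divide by norm here (or lambdify again after normalizing) to keep the curve comparable with the density=True histogram.

plt.plot(r, f_numpy(r) / norm)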
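
A curve drawn over a density=True histogram only lines up with the bars if it integrates to 1 over the plotted range. The sketch below checks this numerically. The f_numpy function is a hypothetical stand-in for the lambdified expression, and trapezoid_area is a helper introduced only for this illustration.

```python
import numpy as np


def f_numpy(x):
    """Hypothetical stand-in for the lambdified f(x) = |sin(-x^2)|."""
    return np.abs(np.sin(-x**2))


def trapezoid_area(y, x):
    """Trapezoid-rule integral of samples y over the grid x."""
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))


r = np.linspace(0.0, 3.0, 1001)

raw = f_numpy(r)                  # unnormalized curve
norm = trapezoid_area(raw, r)     # same role as norm = primitive[-1]
pdf = raw / norm                  # curve to overlay on the histogram

raw_area = trapezoid_area(raw, r)
pdf_area = trapezoid_area(pdf, r)

print(f"area under unnormalized curve = {raw_area}")
print(f"area under normalized curve   = {pdf_area}")
```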

plt.xlabel("$r$")
plt.ylabel("$f(r)$")

plt.show()