dist function errors with a callable metric #378

@zdk123

Describe the bug

According to the docstring of ot.dist, the metric argument can be a callable.

However, L235 clearly expects only a string.

It looks like this was broken in commit 0c5899.
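One possible fix is to guard the string check with isinstance, so that string methods are only called on strings. The sketch below is not POT's actual code: dist_sketch is a hypothetical stand-in, and it assumes that ot.dist hands NumPy inputs to scipy.spatial.distance.cdist, which accepts both metric names and callables.

```python
import numpy as np
from scipy.spatial.distance import cdist


def dist_sketch(x1, x2, metric="sqeuclidean", p=2):
    """Hypothetical stand-in for the dispatch step, not POT's actual code."""
    # Only call string methods when metric really is a string.
    if isinstance(metric, str) and metric.endswith("minkowski"):
        return cdist(x1, x2, metric=metric, p=p)
    # Other metric names and callables both go straight to cdist.
    return cdist(x1, x2, metric=metric)
```

With this guard, a callable never reaches metric.endswith, so the AttributeError cannot occur on that line.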

To Reproduce

Steps to reproduce the behavior:

import ot
import numpy as np

n = 10
a = np.random.randn(n,3)
b = np.random.randn(n,3)

eucl = lambda a, b: np.sum(np.square(a - b))
ot.dist(a, b, metric = eucl)

This raises:

    if metric.endswith("minkowski"):
AttributeError: 'function' object has no attribute 'endswith'
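Until this is fixed, a possible workaround is to build the cost matrix with scipy.spatial.distance.cdist directly, since cdist accepts a callable that takes two 1-D arrays. This is a sketch on NumPy inputs only. The seeded generator is an assumption added for reproducibility.

```python
import numpy as np
from scipy.spatial.distance import cdist

n = 10
rng = np.random.default_rng(0)  # seeded only to make the example reproducible
a = rng.standard_normal((n, 3))
b = rng.standard_normal((n, 3))

# cdist calls the metric on one row of a and one row of b at a time.
eucl = lambda u, v: np.sum(np.square(u - v))

M = cdist(a, b, metric=eucl)  # (n, n) matrix of pairwise costs
```

The resulting M can be passed to a solver that takes a precomputed cost matrix. For this particular metric it matches cdist(a, b, "sqeuclidean").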

Output of the following code snippet:

import ot; print("POT", ot.__version__)
# POT 0.8.2
