
Possible error: transform_labels() function in ot.da.BaseTransport #207

@samarth4149

Describe the bug

The output of transform_labels() is not a probability distribution over labels, i.e. its rows do not sum to 1.

To Reproduce

Run the following code snippet:

Code sample

import ot
import numpy as np

ot_sinkhorn = ot.da.SinkhornTransport()
Xs = np.array([[1, 0], [0, 0]])
ys = np.array([0, 1])
Xt = np.array([[1, 0], [2, 0], [3, 0]])
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
print(ot_sinkhorn.transform_labels(ys))

Output:

[[0.07946862 0.58719805]
 [0.33333333 0.33333333]
 [0.58719805 0.07946861]]
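Each row of the output above sums to about 0.667, i.e. 2/3, which is n_s/n_t for the 2 source and 3 target samples. The numpy-only sketch below reproduces that sum under stated assumptions: it computes an entropic coupling with a plain Sinkhorn loop (`sinkhorn_coupling` is a hypothetical helper, not POT's solver), assumes uniform sample weights and a squared Euclidean cost, and assumes the labels are propagated by normalizing the coupling over its source rows. That last point is a guess about the internals of transform_labels(), not something confirmed here.

```python
import numpy as np


def sinkhorn_coupling(a, b, M, reg=1.0, n_iter=1000):
    """Hypothetical helper: entropic OT coupling via plain Sinkhorn iterations."""
    K = np.exp(-M / reg)  # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)    # match the source marginal
        v = b / (K.T @ u)  # match the target marginal
    return u[:, None] * K * v[None, :]


# Same toy data as the report, as floats.
Xs = np.array([[1.0, 0.0], [0.0, 0.0]])
ys = np.array([0, 1])
Xt = np.array([[1.0, 0.0], [2.0, 0.0], [3.0, 0.0]])

ns, nt = len(Xs), len(Xt)
a = np.full(ns, 1.0 / ns)  # assumed uniform source weights
b = np.full(nt, 1.0 / nt)  # assumed uniform target weights
M = ((Xs[:, None, :] - Xt[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean cost

G = sinkhorn_coupling(a, b, M)

# Assumed propagation: normalize each source row, then push one-hot labels to targets.
D1 = np.eye(ys.max() + 1)[ys]                   # (ns, n_classes) one-hot labels
transp_rows = G / G.sum(axis=1, keepdims=True)  # rows sum to 1
labels_t = transp_rows.T @ D1                   # (nt, n_classes)

print(labels_t.sum(axis=1))  # every entry is close to ns / nt = 2/3, not 1
```

The sum does not depend on the coupling values: after row normalization the entry for target j sums to sum_i G_ij / a_i = n_s * b_j = n_s / n_t under uniform weights, so the rows can only sum to 1 when n_s = n_t.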

Expected behavior

I expected each row of the output to sum to 1, since the transported labels should form a probability distribution over classes.
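A minimal sketch of the expected behavior, assuming the fix is to normalize the coupling over target columns instead of source rows: each target sample then receives a distribution over source samples, and pushing the one-hot source labels through it yields a distribution over classes. `propagate_labels` is a hypothetical helper for illustration, not POT's API, and the coupling `G` is an arbitrary non-negative example rather than a solver output.

```python
import numpy as np


def propagate_labels(G, ys):
    """Hypothetical helper: turn source labels into per-target class distributions.

    G  : (ns, nt) non-negative coupling matrix
    ys : (ns,) integer source labels in 0..n_classes-1
    Returns an (nt, n_classes) array whose rows sum to 1 for targets with mass.
    """
    D1 = np.eye(ys.max() + 1)[ys]            # (ns, n_classes) one-hot labels
    col_mass = G.sum(axis=0, keepdims=True)  # mass received by each target sample
    # Normalize each target column into a distribution over source samples;
    # targets that receive no mass get an all-zero row instead of NaNs.
    transp = np.divide(G, col_mass, out=np.zeros_like(G, dtype=float), where=col_mass > 0)
    return transp.T @ D1


# An arbitrary non-negative coupling between 2 source and 3 target samples.
G = np.array([[0.30, 0.15, 0.05],
              [0.03, 0.18, 0.29]])
ys = np.array([0, 1])

print(propagate_labels(G, ys).sum(axis=1))  # each row sums to 1
```

Normalizing per column makes the row sums of the result independent of n_s and n_t, which is the property the report asks for; the masked division is one way to avoid NaNs for targets that receive no mass.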
