
Custom cost function? #164

Answered by marcocuturi
prabhant asked this question in Q&A

Hi @prabhant!

To define a custom cost function, create a class that inherits from CostFn and, at the very least, defines a pairwise method that computes the distance between two vectors (at the moment we only handle distance functions between vectors).

The simplest example is the Euclidean distance:

class Euclidean(CostFn):
  """Euclidean distance."""
  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
    """Compute Euclidean norm."""
    return jnp.linalg.norm(x - y)

So, for instance, you could define your own distance:

import jax.numpy as jnp
from ott.geometry import costs
class MyCost(costs.CostFn):
  """My weird cost."""
  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
…
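
To make the pattern above concrete, here is a minimal sketch of a complete custom cost. It deliberately does not import ott: `CostFnSketch` and its `all_pairs` helper are hypothetical stand-ins written for this example, assuming only that the real base class needs a vector-to-vector `pairwise` method and lifts it to whole point clouds. The `Manhattan` cost is likewise just an illustration, not something taken from the answer.

```python
import jax
import jax.numpy as jnp


class CostFnSketch:
  """Hypothetical stand-in for the CostFn base class (not the ott class)."""

  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
    """Cost between two single vectors; subclasses override this."""
    raise NotImplementedError

  def all_pairs(self, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:
    """Lift the vector-to-vector cost to an (n, m) cost matrix via vmap."""
    return jax.vmap(lambda x: jax.vmap(lambda y: self.pairwise(x, y))(ys))(xs)


class Manhattan(CostFnSketch):
  """Illustrative custom cost: L1 distance between two vectors."""

  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
    return jnp.sum(jnp.abs(x - y))


# Two small point clouds: 2 points and 3 points in the plane.
xs = jnp.array([[0.0, 0.0], [1.0, 2.0]])
ys = jnp.array([[1.0, 1.0], [0.0, 0.0], [3.0, -1.0]])

cost_matrix = Manhattan().all_pairs(xs, ys)  # shape (2, 3)
```

The point of the sketch is that `pairwise` only ever sees two vectors, so it can be written with plain `jnp` operations and stays compatible with `jax.vmap`. With the real library you would inherit from `costs.CostFn` instead of the stand-in; how the resulting object is handed to a geometry should be checked against the installed ott version.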

Answer selected by prabhant