Skip to content

Commit

Permalink
Sync with open source and update version string to 3.0.7post1 (#131)
Browse files Browse the repository at this point in the history
Co-authored-by: Zeming Lin <[email protected]>
  • Loading branch information
ebetica and Zeming Lin authored Oct 23, 2024
1 parent 2b5defd commit 39a3a6c
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CONTRIBUTIONS.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,5 @@ python examples/raw_forwards.py
`examples/esmprotein.ipynb` works. Remember to skip running the first cell - it will reinstall stock esm instead of your deployed version.

3. Ensure `examples/generate.ipynb` works. Note this notebook will require a node with a GPU that can fit ESM3 small open.

4. Ensure
2 changes: 1 addition & 1 deletion esm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.0.7"
__version__ = "3.0.7post1"
8 changes: 8 additions & 0 deletions esm/sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,14 @@ class GenerationConfig:
condition_on_coordinates_only: bool = True


@define
class InverseFoldingConfig:
invalid_ids: Sequence[int] = []
schedule: str = "cosine"
num_steps: int = 1
temperature: float = 1.0


## Low Level Endpoint Types
@define
class SamplingTrackConfig:
Expand Down
69 changes: 68 additions & 1 deletion esm/sdk/forge.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ def retry_if_specific_error(exception):
We only retry on specific errors.
Currently we retry for 502 (bad gateway) and 429 (rate limit)
"""
return isinstance(exception, ESMProteinError) and exception.error_code in {429, 502}
return isinstance(exception, ESMProteinError) and exception.error_code in {
429,
502,
504,
}


def log_retry_attempt(retry_state):
Expand All @@ -51,6 +55,69 @@ def log_retry_attempt(retry_state):
)


class FoldForgeInferenceClient:
def __init__(
self,
url: str = "https://forge.evolutionaryscale.ai",
token: str = "",
request_timeout: int | None = None,
):
if token == "":
raise RuntimeError(
"Please provide a token to connect to Forge via token=YOUR_API_TOKEN_HERE"
)
self.url = url
self.token = token
self.headers = {"Authorization": f"Bearer {self.token}"}
self.request_timeout = request_timeout

def fold(
self,
model_name: str,
sequence: str,
potential_sequence_of_concern: bool,
) -> torch.Tensor | ESMProteinError:
request = {
"model": model_name,
"sequence": sequence,
}
try:
data = self._post(
"fold",
request,
potential_sequence_of_concern,
)
except ESMProteinError as e:
return e

return data["coordinates"]

def _post(self, endpoint, request, potential_sequence_of_concern):
request["potential_sequence_of_concern"] = potential_sequence_of_concern

model_name_url = request["model"] if request["model"] != "esm3" else "api"
response = requests.post(
urljoin(self.url, f"/{model_name_url}/v1/{endpoint}"),
json=request,
headers=self.headers,
timeout=self.request_timeout,
)

if not response.ok:
raise ESMProteinError(
error_code=response.status_code,
error_msg=f"Failure in {endpoint}: {response.text}",
)

data = response.json()
# Nextjs puts outputs dict under "data" key.
# Lift it up for easier downstream processing.
if "outputs" not in data and "data" in data:
data = data["data"]

return data


class ESM3ForgeInferenceClient(ESM3InferenceClient):
def __init__(
self,
Expand Down
2 changes: 0 additions & 2 deletions esm/utils/structure/protein_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,6 @@ def from_pdb(cls, path: PathOrBuffer, id: str | None = None) -> "ProteinComplex"
if len(chain) == 0:
continue
chains.append(ProteinChain.from_atomarray(chain, id))
print("IS INF")
print(np.isinf(chains[-1].atom37_positions).any())
return ProteinComplex.from_chains(chains)

@classmethod
Expand Down
66 changes: 66 additions & 0 deletions examples/folding_inverse_folding_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import cast

import numpy as np

from examples.local_generate import get_sample_protein
from esm.sdk.api import (
ESM3InferenceClient,
ESMProtein,
GenerationConfig,
)
from esm.sdk.forge import FoldForgeInferenceClient


def convert_none_to_nan(data):
"""Recursively convert None values in any deeply nested structure (e.g., list of lists of lists) to np.nan."""
if isinstance(data, list):
return [convert_none_to_nan(x) for x in data]
elif data is None:
return np.nan
else:
return data


def are_allclose_with_nan(A, B, rtol=1e-5, atol=1e-2):
B = convert_none_to_nan(B)

A = np.array(A)
B = np.array(B)

if A.shape != B.shape:
raise ValueError("A and B must have the same shape")

nan_mask_A = np.isnan(A)
nan_mask_B = np.isnan(B)

if not np.array_equal(nan_mask_A, nan_mask_B):
return False

return np.allclose(A[~nan_mask_A], B[~nan_mask_B], rtol=rtol, atol=atol)


def main(fold_client: FoldForgeInferenceClient, esm3_client: ESM3InferenceClient):
# Folding
protein = get_sample_protein()
sequence_length = len(protein.sequence) # type: ignore
num_steps = int(sequence_length / 16)
protein.coordinates = None
protein.function_annotations = None
protein.sasa = None
# Folding with esm3 client
folded_protein = cast(
ESMProtein,
esm3_client.generate(
protein,
GenerationConfig(
track="structure", schedule="cosine", num_steps=num_steps, temperature=0
),
),
)
# Folding with folding client
coordinates = fold_client.fold(
"esm3",
protein.sequence, # type:ignore
potential_sequence_of_concern=False,
)
assert are_allclose_with_nan(folded_protein.coordinates, coordinates)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "esm"
version = "3.0.7"
version = "3.0.7post1"
description = "EvolutionaryScale open model repository"
readme = "README.md"
requires-python = ">=3.10"
Expand Down

0 comments on commit 39a3a6c

Please sign in to comment.