diff --git a/src/ostorlab/runtimes/local/models/alembic/versions/b31a27c8584f_add_domain_asset_model.py b/src/ostorlab/runtimes/local/models/alembic/versions/b31a27c8584f_add_domain_asset_model.py new file mode 100644 index 000000000..5a982f8bd --- /dev/null +++ b/src/ostorlab/runtimes/local/models/alembic/versions/b31a27c8584f_add_domain_asset_model.py @@ -0,0 +1,49 @@ +"""Add domain asset model + +Revision ID: b31a27c8584f +Revises: 3aa6ae380275 +Create Date: 2024-06-13 17:05:48.425291 + +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "b31a27c8584f" +down_revision = "3aa6ae380275" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "domain_asset", + sa.Column("id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["id"], ["asset.id"], name=op.f("fk_domain_asset_id_asset") + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_domain_asset")), + ) + op.create_table( + "domain_name", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(length=255), nullable=True), + sa.Column("domain_asset_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["domain_asset_id"], + ["domain_asset.id"], + name=op.f("fk_domain_name_domain_asset_id_domain_asset"), + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_domain_name")), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("domain_name") + op.drop_table("domain_asset") + # ### end Alembic commands ### diff --git a/src/ostorlab/runtimes/local/models/models.py b/src/ostorlab/runtimes/local/models/models.py index 0ac9d0cfa..f51393257 100644 --- a/src/ostorlab/runtimes/local/models/models.py +++ b/src/ostorlab/runtimes/local/models/models.py @@ -1118,3 +1118,78 @@ def delete(network_id: int) -> None: session.query(Network).filter_by(id=network_id).delete() session.query(IPRange).filter_by(network_asset_id=network_id).delete() session.commit() + + +class DomainName(Base): + __tablename__ = "domain_name" + id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True) + name = sqlalchemy.Column(sqlalchemy.String(255)) + domain_asset_id = sqlalchemy.Column( + sqlalchemy.Integer, sqlalchemy.ForeignKey("domain_asset.id") + ) + + @staticmethod + def create(name: str, domain_asset_id: Optional[int] = None) -> "DomainName": + """Persist the domain name information in the database. + + Args: + name: The domain name. + domain_asset_id: The domain asset id. + + Returns: + DomainName object. + """ + with Database() as session: + domain_name = DomainName(name=name, domain_asset_id=domain_asset_id) + session.add(domain_name) + session.commit() + return domain_name + + +class DomainAsset(Asset): + __tablename__ = "domain_asset" + id = sqlalchemy.Column( + sqlalchemy.Integer, sqlalchemy.ForeignKey("asset.id"), primary_key=True + ) + + __mapper_args__ = { + "polymorphic_identity": "domain_asset", + } + + @staticmethod + def create( + domains: List[Dict[str, str]], scan_id: Optional[int] = None + ) -> "DomainAsset": + """Persist the domain asset information in the database. + + Args: + domains: list of domain names. + scan_id: The scan id. + + Returns: + DomainAsset object. + """ + with Database() as session: + domain_asset_instance = DomainAsset(scan_id=scan_id) + session.add(domain_asset_instance) + session.commit() + + for domain in domains: + DomainName.create( + name=domain.get("name"), domain_asset_id=domain_asset_instance.id + ) + return domain_asset_instance + + @staticmethod + def delete(domain_asset_id: int) -> None: + """Delete the domain asset information from the database. + + Args: + domain_asset_id: The domain asset id. + """ + with Database() as session: + session.query(DomainAsset).filter_by(id=domain_asset_id).delete() + session.query(DomainName).filter_by( + domain_asset_id=domain_asset_id + ).delete() + session.commit() diff --git a/src/ostorlab/serve_app/oxo.py b/src/ostorlab/serve_app/oxo.py index 6be678a7e..dc951a782 100644 --- a/src/ostorlab/serve_app/oxo.py +++ b/src/ostorlab/serve_app/oxo.py @@ -31,6 +31,7 @@ from ostorlab.assets import ipv4 as ipv4_address_asset from ostorlab.assets import ipv6 as ipv6_address_asset from ostorlab.assets import link as link_asset +from ostorlab.assets import domain_name as domain_name_asset from ostorlab.assets import asset as ostorlab_asset DEFAULT_NUMBER_ELEMENTS = 15 @@ -372,6 +373,9 @@ def mutate( if asset.ip is not None: new_asset = models.Network.create(networks=asset.ip) created_assets.append(new_asset) + if asset.domain is not None: + new_asset = models.DomainAsset.create(domains=asset.domain) + created_assets.append(new_asset) if len(errors) > 0: error_messages = "\n".join(errors) raise graphql.GraphQLError(f"Invalid assets: {error_messages}") @@ -396,6 +400,8 @@ def _validate(asset: types.OxoAssetInputType) -> Optional[str]: assets.append(asset.link) if asset.ip is not None: assets.append(asset.ip) + if asset.domain is not None: + assets.append(asset.domain) if len(assets) == 0: return f"Asset {asset} input is missing target." @@ -654,6 +660,16 @@ def _prepare_assets(asset_ids: List[int]) -> List[ostorlab_asset.Asset]: scan_assets.append( link_asset.Link(url=link.url, method=link.method) ) + elif asset.type == "domain_asset": + domains = ( + session.query(models.DomainName) + .filter_by(domain_asset_id=asset.id) + .all() + ) + for domain in domains: + scan_assets.append( + domain_name_asset.DomainName(name=domain.name) + ) else: raise graphql.GraphQLError("Unsupported asset type.") diff --git a/src/ostorlab/serve_app/types.py b/src/ostorlab/serve_app/types.py index e67aa79c0..987d15132 100644 --- a/src/ostorlab/serve_app/types.py +++ b/src/ostorlab/serve_app/types.py @@ -328,6 +328,43 @@ def resolve_networks(self, info) -> List[OxoIPRangeAssetType]: return [OxoIPRangeAssetType(host=ip.host, mask=ip.mask) for ip in ips] +class OxoDomainNameAssetType(graphene_sqlalchemy.SQLAlchemyObjectType): + class Meta: + model = models.DomainName + only_fields = "name" + + +class OxoDomainNameAssetsType(graphene_sqlalchemy.SQLAlchemyObjectType): + domain_names = graphene.List(OxoDomainNameAssetType, required=False) + + class Meta: + model = models.DomainAsset + only_fields = ("id",) + + def resolve_domain_names( + self, info: graphql_base.ResolveInfo + ) -> List[OxoDomainNameAssetType]: + """Resolve domain names query. + + Args: + self: The domain asset object. + info: GraphQL resolve info. + + Returns: + List of domain names. + """ + with models.Database() as session: + domain_names = ( + session.query(models.DomainName) + .filter_by(domain_asset_id=self.id) + .all() + ) + return [ + OxoDomainNameAssetType(name=domain_name.name) + for domain_name in domain_names + ] + + class OxoAssetType(graphene.Union): class Meta: model = models.Asset @@ -338,6 +375,7 @@ class Meta: OxoIOSStoreAssetType, OxoUrlsAssetType, OxoNetworkAssetType, + OxoDomainNameAssetsType, ) @@ -785,6 +823,10 @@ class OxoLinkInputType(graphene.InputObjectType): method = graphene.String(required=False, default_value="GET") +class OxoDomainNameInputType(graphene.InputObjectType): + name = graphene.String(required=True) + + class OxoAssetInputType(graphene.InputObjectType): android_apk_file = graphene.List(OxoAndroidFileAssetInputType) android_aab_file = graphene.List(OxoAndroidFileAssetInputType) @@ -793,6 +835,7 @@ class OxoAssetInputType(graphene.InputObjectType): ios_store = graphene.List(OxoIOSStoreAssetInputType) link = graphene.List(OxoLinkInputType) ip = graphene.List(OxoIPRangeInputType) + domain = graphene.List(OxoDomainNameInputType) class AgentArgumentInputType(graphene.InputObjectType): diff --git a/tests/conftest.py b/tests/conftest.py index c3c8171dd..125258266 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1111,6 +1111,18 @@ def url_asset(mocker: plugin.MockerFixture, db_engine_path: str) -> models.Urls: return asset +@pytest.fixture +def domain_asset( + mocker: plugin.MockerFixture, db_engine_path: str +) -> models.DomainName: + """Create a DomainName asset.""" + mocker.patch.object(models, "ENGINE_URL", db_engine_path) + asset = models.DomainAsset.create( + domains=[{"name": "google.com"}, {"name": "tesla.com"}] + ) + return asset + + @pytest.fixture def android_file_asset( mocker: plugin.MockerFixture, db_engine_path: str @@ -1233,3 +1245,23 @@ def run_scan_mock(mocker: plugin.MockerFixture) -> None: return_value=True, ) mocker.patch("ostorlab.runtimes.local.runtime.LocalRuntime._inject_assets") + + +@pytest.fixture +def run_scan_mock2(mocker: plugin.MockerFixture) -> None: + """Mock functions required to run a scan.""" + mocker.patch( + "ostorlab.cli.docker_requirements_checker.is_docker_installed", + return_value=True, + ) + mocker.patch( + "ostorlab.cli.docker_requirements_checker.is_docker_working", return_value=True + ) + mocker.patch( + "ostorlab.cli.docker_requirements_checker.is_swarm_initialized", + return_value=True, + ) + mocker.patch("docker.from_env") + mocker.patch( + "ostorlab.runtimes.local.runtime.LocalRuntime.can_run", return_value=True + ) diff --git a/tests/runtimes/local/models/models_test.py b/tests/runtimes/local/models/models_test.py index b4bc9b906..3786f3797 100644 --- a/tests/runtimes/local/models/models_test.py +++ b/tests/runtimes/local/models/models_test.py @@ -332,6 +332,28 @@ def testAssetModels_whenCreateNetwork_assetCreated( assert ips[1].mask == "32" +def testAssetModels_whenCreateDomainAsset_assetCreated( + mocker: plugin.MockerFixture, db_engine_path: str +) -> None: + """Ensure we correctly persist the domain information.""" + mocker.patch.object(models, "ENGINE_URL", db_engine_path) + models.DomainAsset.create( + domains=[{"name": "domain1.test.ma"}, {"name": "domain2.test.ma"}] + ) + + with models.Database() as session: + assert session.query(models.DomainAsset).count() == 1 + domain_id = session.query(models.DomainAsset).all()[0].id + domain_names = ( + session.query(models.DomainName) + .filter(models.DomainName.domain_asset_id == domain_id) + .all() + ) + assert len(domain_names) == 2 + assert domain_names[0].name == "domain1.test.ma" + assert domain_names[1].name == "domain2.test.ma" + + def testAssetModels_whenCreateUrl_assetCreated( mocker: plugin.MockerFixture, db_engine_path: str ) -> None: diff --git a/tests/serve_app/oxo_test.py b/tests/serve_app/oxo_test.py index 01d8b7ec4..b0c86437c 100644 --- a/tests/serve_app/oxo_test.py +++ b/tests/serve_app/oxo_test.py @@ -1421,6 +1421,58 @@ def testCreateAsset_network_createsNewAsset( assert networks[1].mask == "32" +def testCreateAsset_domain_createsNewAsset( + authenticated_flask_client: testing.FlaskClient, clean_db: None +) -> None: + """Ensure the domain asset is created successfully through the createAssets API.""" + del clean_db + query = """ + mutation createDomain($assets: [OxoAssetInputType]!) { + createAssets(assets: $assets) { + assets { + ... on OxoDomainNameAssetsType { + id + domainNames { + name + } + } + } + } + } + """ + + resp = authenticated_flask_client.post( + "/graphql", + json={ + "query": query, + "variables": { + "assets": [ + {"domain": [{"name": "www.google.com"}, {"name": "www.tesla.com"}]} + ] + }, + }, + ) + + assert resp.status_code == 200, resp.get_json() + asset_data = resp.get_json()["data"]["createAssets"]["assets"][0] + assert asset_data["id"] is not None + assert asset_data["domainNames"] == [ + {"name": "www.google.com"}, + {"name": "www.tesla.com"}, + ] + with models.Database() as session: + assert session.query(models.DomainAsset).count() == 1 + domain_asset_id = session.query(models.DomainAsset).first().id + domains = ( + session.query(models.DomainName) + .filter_by(domain_asset_id=domain_asset_id) + .all() + ) + assert len(domains) == 2 + assert domains[0].name == "www.google.com" + assert domains[1].name == "www.tesla.com" + + def testCreateAsset_androidApkFile_createsNewAsset( authenticated_flask_client: testing.FlaskClient, clean_db: None ) -> None: @@ -2243,6 +2295,70 @@ def testRunScanMutation_whenNetworkAsset_shouldRunScan( assert args["assets"][1].mask == "24" +def testRunScanMutation_whenDomainAsset_shouldRunScan( + authenticated_flask_client: testing.FlaskClient, + agent_group_nmap: models.AgentGroup, + domain_asset: models.DomainAsset, + scan: models.Scan, + mocker: plugin.MockerFixture, + run_scan_mock2: None, +) -> None: + """Test RunScanMutation for Domain asset.""" + scan_mock = mocker.patch( + "ostorlab.runtimes.local.runtime.LocalRuntime.scan", return_value=scan + ) + query = """ + mutation RunScan($scan: OxoAgentScanInputType!) { + runScan( + scan: $scan + ) { + scan { + id + title + progress + assets { + ... on OxoDomainNameAssetsType { + id + domainNames { + name + } + } + } + } + } + } + """ + variables = { + "scan": { + "title": "Test Scan Domain Asset", + "assetIds": [domain_asset.id], + "agentGroupId": agent_group_nmap.id, + }, + } + + response = authenticated_flask_client.post( + "/graphql", json={"query": query, "variables": variables} + ) + + assert response.status_code == 200, response.get_json() + res_scan = response.get_json()["data"]["runScan"]["scan"] + assert int(res_scan["id"]) == scan.id + assert res_scan["title"] == scan.title + assert res_scan["progress"] == scan.progress.name + assert len(res_scan["assets"]) == 1 + assert int(res_scan["assets"][0]["id"]) == domain_asset.id + assert res_scan["assets"][0]["domainNames"] == [ + {"name": "google.com"}, + {"name": "tesla.com"}, + ] + args = scan_mock.call_args[1] + assert args["title"] == "Test Scan Domain Asset" + assert args["agent_group_definition"].agents[0].key == "agent/ostorlab/nmap" + assert len(args["assets"]) == 2 + assert args["assets"][0].name == "google.com" + assert args["assets"][1].name == "tesla.com" + + def testRunScanMutation_whenUrl_shouldRunScan( authenticated_flask_client: testing.FlaskClient, agent_group_nmap: models.AgentGroup,