diff --git a/src/pydantic_avro/from_avro/class_registery.py b/src/pydantic_avro/from_avro/class_registery.py index adf2204..b9dedf9 100644 --- a/src/pydantic_avro/from_avro/class_registery.py +++ b/src/pydantic_avro/from_avro/class_registery.py @@ -13,7 +13,9 @@ def __new__(cls): def add_class(self, name: str, class_def: str): """Add a class to the registry.""" - self._classes[name] = class_def + formatted_name = self.format_class_name(name) + formatted_class_def = self.replace_class_def_name(name, class_def) + self._classes[formatted_name] = formatted_class_def @property def classes(self) -> dict: @@ -27,3 +29,12 @@ def has_class(self, name: str) -> bool: def clear(self): """Clear all classes from the registry.""" self._classes.clear() + + def format_class_name(self, name: str) -> str: + """Format the class name to be Pythonic.""" + return name.replace("_", "") + + def replace_class_def_name(self, class_name: str, class_def: str) -> str: + """Format the class definition to be Pythonic.""" + formatted_class_name = self.format_class_name(class_name) + return class_def.replace(class_name, formatted_class_name) diff --git a/tests/test_from_avro.py b/tests/test_from_avro.py index d2626c3..95b05dd 100644 --- a/tests/test_from_avro.py +++ b/tests/test_from_avro.py @@ -3,6 +3,7 @@ import pytest from pydantic_avro.from_avro.avro_to_pydantic import avsc_to_pydantic, convert_file +from pydantic_avro.from_avro.class_registery import ClassRegistry def test_avsc_to_pydantic_empty(): @@ -25,6 +26,17 @@ def test_avsc_to_pydantic_missing_fields(): avsc_to_pydantic({"name": "Test", "type": "record"}) +def test_avsc_to_pydantic_class_name_formatting(): + pydantic_code = avsc_to_pydantic( + { + "name": "Test_Name", + "type": "record", + "fields": [], + } + ) + assert "class TestName(BaseModel):\n pass" in pydantic_code + + def test_avsc_to_pydantic_primitive(): pydantic_code = avsc_to_pydantic( { @@ -59,7 +71,10 @@ def test_avsc_to_pydantic_map(): "name": "Test", "type": "record", "fields": [ - {"name": "col1", "type": {"type": "map", "values": "string", "default": {}}}, + { + "name": "col1", + "type": {"type": "map", "values": "string", "default": {}}, + }, ], } ) @@ -73,7 +88,10 @@ def test_avsc_to_pydantic_map_missing_values(): "name": "Test", "type": "record", "fields": [ - {"name": "col1", "type": {"type": "map", "values": None, "default": {}}}, + { + "name": "col1", + "type": {"type": "map", "values": None, "default": {}}, + }, ], } ) @@ -215,7 +233,11 @@ def test_default(): "fields": [ {"name": "col1", "type": "string", "default": "test"}, {"name": "col2_1", "type": ["null", "string"], "default": None}, - {"name": "col2_2", "type": ["string", "null"], "default": "default_str"}, + { + "name": "col2_2", + "type": ["string", "null"], + "default": "default_str", + }, { "name": "col3", "type": {"type": "map", "values": "string"}, @@ -245,7 +267,11 @@ def test_enums(): "fields": [ { "name": "c1", - "type": {"type": "enum", "symbols": ["passed", "failed"], "name": "Status"}, + "type": { + "type": "enum", + "symbols": ["passed", "failed"], + "name": "Status", + }, }, ], } @@ -264,7 +290,11 @@ def test_enums_reuse(): "fields": [ { "name": "c1", - "type": {"type": "enum", "symbols": ["passed", "failed"], "name": "Status"}, + "type": { + "type": "enum", + "symbols": ["passed", "failed"], + "name": "Status", + }, }, {"name": "c2", "type": "Status"}, ], @@ -291,7 +321,12 @@ def test_unions(): { "type": "record", "name": "ARecord", - "fields": [{"name": "values", "type": {"type": "map", "values": "string"}}], + "fields": [ + { + "name": "values", + "type": {"type": "map", "values": "string"}, + } + ], }, ], },