diff --git a/src/main.rs b/src/main.rs index a2eb818..efdc1af 100644 --- a/src/main.rs +++ b/src/main.rs @@ -184,20 +184,20 @@ mod tests { let matches = cmd.try_get_matches_from(vec![ "oxigration", - "generate", + "migrate", "-d", - "test_schemas/", + "tests/schemas/baseline/", "-c", "postgresql://test@localhost/test", ]); assert!(matches.is_ok()); let matches = matches.unwrap(); - assert_eq!(matches.subcommand_name(), Some("generate")); + assert_eq!(matches.subcommand_name(), Some("migrate")); if let Some(sub_matches) = matches.subcommand_matches("generate") { assert_eq!( sub_matches.get_one::("dir").unwrap(), - "test_schemas/" + "tests/schemas/baseline/" ); assert_eq!( sub_matches.get_one::("connection").unwrap(), @@ -212,20 +212,20 @@ mod tests { let matches = cmd.try_get_matches_from(vec![ "oxigration", - "migrate", + "generate", "-d", - "test_schemas/", + "tests/schemas/generated/", "-c", "postgresql://test@localhost/test", ]); assert!(matches.is_ok()); let matches = matches.unwrap(); - assert_eq!(matches.subcommand_name(), Some("migrate")); + assert_eq!(matches.subcommand_name(), Some("generate")); if let Some(sub_matches) = matches.subcommand_matches("migrate") { assert_eq!( sub_matches.get_one::("dir").unwrap(), - "test_schemas/" + "tests/schemas/generated/" ); assert_eq!( sub_matches.get_one::("connection").unwrap(), diff --git a/src/source_code.rs b/src/source_code.rs index cdad2f5..ec0ee72 100644 --- a/src/source_code.rs +++ b/src/source_code.rs @@ -2,7 +2,7 @@ use crate::utils::topsort::topo_sort; use core::ops::ControlFlow; use indexmap::IndexMap; use sqlparser::ast::{ObjectName, Statement, Visitor}; -use sqlparser::dialect::GenericDialect; +use sqlparser::dialect::PostgreSqlDialect; use sqlparser::parser::Parser; use std::collections::{HashMap, HashSet}; use std::error::Error; @@ -26,7 +26,7 @@ pub struct DatabaseObject { /// Additional properties associated with the database object. pub properties: HashMap, /// The parsed SQL content of the database object. - pub parsed_content: Vec, + pub parsed_content: Option, } impl DatabaseObject { /// Creates a new DatabaseObject with the given parameters. @@ -35,7 +35,7 @@ impl DatabaseObject { value: String, mut dependencies: HashSet, properties: HashMap, - parsed_content: Vec, + parsed_content: Option, ) -> Self { // Check if properties contain a "depends" key and add its value to dependencies if let Some(depends) = properties.get("depends") { @@ -68,7 +68,17 @@ impl SqlVisitor { } fn visit_object_name(&mut self, name: &ObjectName) { - self.object_name = name.to_string(); + // If the ObjectName is in the format of schema.name, return only name + self.object_name = name + .0 + .last() + .map(|ident| ident.value.clone()) + .unwrap_or_default(); + self.schema_name = name + .0 + .first() + .map(|ident| ident.value.clone()) + .unwrap_or_default(); } } impl Visitor for SqlVisitor { @@ -186,18 +196,17 @@ pub fn read_source_code( // Parse the SQL statements in the file let parsed_stmts = parse_change_stmts(&contents, "//// CHANGE", "GO", "name"); // Iterate over the parsed statements - for (_, stmt) in parsed_stmts { + for (_, mut stmt) in parsed_stmts { // Build a relational object from the parsed statement - match build_relational_object( + match relational_object_conformance( file_path, schema_name, object_type, &contents, - Some(&stmt), + &mut stmt, ) { - Ok(relational_object) => { - object_info - .insert(relational_object.change_name.clone(), relational_object); + Ok(_) => { + object_info.insert(stmt.change_name.clone(), stmt); } Err(e) => return Err(e), } @@ -212,19 +221,19 @@ pub fn read_source_code( Ok(ordered_object_info) } -/// Builds a `DatabaseObject` from the given parameters. +/// Updates a `DatabaseObject` with the given parameters. /// /// This function takes several parameters including the file path, schema name, object type, -/// contents of the SQL file, and an optional `DatabaseObject` statement. It parses the SQL +/// contents of the SQL file, and a mutable reference to a `DatabaseObject`. It parses the SQL /// content to extract the first SQL object and uses a visitor to traverse the SQL statement /// and gather necessary information such as the object name and schema name. /// /// The function then constructs a key for the `DatabaseObject` based on the schema name, -/// object type, file name, and change name (if provided). It also extracts dependencies -/// and properties from the optional `DatabaseObject` statement. +/// object type, file name, and change name. It also updates the dependencies and properties +/// of the `DatabaseObject` accordingly. /// /// If the object name extracted from the SQL content does not match the file name, an error -/// is returned. Otherwise, a new `DatabaseObject` is created and returned. +/// is returned. Otherwise, the existing `DatabaseObject` is updated with the new information. /// /// # Arguments /// @@ -232,67 +241,72 @@ pub fn read_source_code( /// * `schema_name` - A string slice representing the schema name. /// * `object_type` - A string slice representing the type of the object (e.g., table, view). /// * `contents` - A string slice containing the contents of the SQL file. -/// * `stmt` - An optional reference to a `DatabaseObject` statement. +/// * `stmt` - A mutable reference to a `DatabaseObject` to be updated. /// /// # Returns /// /// This function returns a `Result` containing: -/// * `Ok(DatabaseObject)` - A `DatabaseObject` constructed from the provided parameters. +/// * `Ok(())` - If the `DatabaseObject` was successfully updated. /// * `Err(Box)` - An error if the object name does not match the file name or if /// there are issues parsing the SQL content. -fn build_relational_object( +fn relational_object_conformance( file_path: &Path, schema_name: &str, object_type: &str, contents: &str, - stmt: Option<&DatabaseObject>, -) -> Result> { - let dialect = GenericDialect {}; - let parsed_content = Parser::parse_sql(&dialect, &stmt.map_or(contents, |s| &s.value))?; - let first_object = parsed_content - .first() - .ok_or("No objects found in parsed content")?; - - let mut visitor = SqlVisitor::new(); - visitor.pre_visit_statement(first_object); // Use pre_visit_statement method - + stmt: &mut DatabaseObject, +) -> Result<(), Box> { + // Extract the file name from the file path let file_name = file_path .file_stem() .and_then(|stem| stem.to_str()) .map(|s| s.to_string()) .ok_or_else(|| format!("Failed to extract file stem from path: {:?}", file_path))?; - if file_name != visitor.object_name { + // Parse the SQL content to extract the first SQL object + let dialect = PostgreSqlDialect {}; + let parsed_content = match Parser::parse_sql(&dialect, &stmt.value) { + Ok(content) => content + .first() + .cloned() // Clone the first element to extend its lifetime + .ok_or("No objects found in parsed content")?, + Err(e) => return Err(Box::new(e)), + }; + + // Use a visitor to traverse the SQL statement and gather necessary information + let mut visitor = SqlVisitor::new(); + visitor.pre_visit_statement(&parsed_content); // Use pre_visit_statement method + + // Check if the file name matches the object name + if file_name != stmt.change_name { return Err(format!( "Object name '{}' in file does not match name '{}' in SQL", file_name, visitor.object_name ) .into()); } - let schema_name = if visitor.schema_name.is_empty() { - schema_name.to_string() - } else { - visitor.schema_name.clone() - }; - let key = match stmt { - Some(stmt) => format!( - "{}.{}.{}.{}", - schema_name, object_type, file_name, stmt.change_name - ), - None => format!("{}.{}.{}.{}", schema_name, object_type, file_name, "root"), - }; + // Check if the schema name matches the object schema + if visitor.schema_name != schema_name { + return Err(format!( + "Schema name '{}' in file does not match schema name '{}' in SQL", + schema_name, visitor.schema_name + ) + .into()); + } + + // Create a unique identifier for the DatabaseObject + let key = format!( + "{}.{}.{}.{}", + schema_name, object_type, file_name, stmt.change_name + ); - let dependencies = stmt.map_or_else(HashSet::new, |s| s.dependencies.clone()); - let properties = stmt.map_or_else(HashMap::new, |s| s.properties.clone()); + // Update the existing DatabaseObject + stmt.change_name = key; + stmt.value = contents.to_string(); + stmt.parsed_content = Some(parsed_content); - Ok(DatabaseObject::new( - key, - stmt.map_or_else(|| contents.to_string(), |s| s.value.clone()), - dependencies, - properties, - parsed_content, - )) + Ok(()) } /// Parses a string containing multiple statements delimited by start and end delimiters, @@ -332,7 +346,7 @@ fn parse_change_stmts( value.trim().to_string(), dependencies.clone(), properties.clone(), - vec![], + None, ), ); dependencies.insert(current_name.clone()); @@ -350,7 +364,7 @@ fn parse_change_stmts( value.trim().to_string(), dependencies.clone(), properties.clone(), - vec![], + None, ), ); dependencies.insert(root_name.clone()); @@ -374,7 +388,7 @@ fn parse_change_stmts( value.trim().to_string(), dependencies, properties, - vec![], + None, ), ); } @@ -459,7 +473,7 @@ mod tests { #[test] fn test_read_source_code_with_valid_directory() { let dir = tempdir().unwrap(); - let file_path = dir.path().join("schema1").join("tables").join("table1.sql"); + let file_path = dir.path().join("schema1").join("table").join("table1.sql"); fs::create_dir_all(file_path.parent().unwrap()).unwrap(); let mut file = File::create(&file_path).unwrap(); writeln!(file, "CREATE TABLE table1 (id INT);").unwrap(); @@ -468,7 +482,7 @@ mod tests { assert!(result.is_ok()); let object_info = result.unwrap(); assert_eq!(object_info.len(), 1); - assert!(object_info.contains_key("schema1.tables.table1.root0")); + assert!(object_info.contains_key("schema1.table.table1.root0")); } #[test] @@ -480,8 +494,8 @@ mod tests { #[test] fn test_read_source_code_with_multiple_files() { let dir = tempdir().unwrap(); - let file_path1 = dir.path().join("schema1").join("tables").join("table1.sql"); - let file_path2 = dir.path().join("schema1").join("views").join("view1.sql"); + let file_path1 = dir.path().join("schema1").join("table").join("table1.sql"); + let file_path2 = dir.path().join("schema1").join("view").join("view1.sql"); fs::create_dir_all(file_path1.parent().unwrap()).unwrap(); fs::create_dir_all(file_path2.parent().unwrap()).unwrap(); let mut file1 = File::create(&file_path1).unwrap(); @@ -493,17 +507,17 @@ mod tests { assert!(result.is_ok()); let object_info = result.unwrap(); assert_eq!(object_info.len(), 2); - assert!(object_info.contains_key("schema1.tables.table1.root0")); - assert!(object_info.contains_key("schema1.views.view1.root0")); + assert!(object_info.contains_key("schema1.table.table1.root0")); + assert!(object_info.contains_key("schema1.view.view1.root0")); } #[test] fn test_read_source_code_with_dependencies() { let dir = tempdir().unwrap(); - let file_path1 = dir.path().join("schema1").join("tables").join("table1.sql"); - let file_path2 = dir.path().join("schema1").join("tables").join("table2.sql"); - let file_path3 = dir.path().join("schema1").join("tables").join("table3.sql"); - let file_path4 = dir.path().join("schema1").join("tables").join("table4.sql"); + let file_path1 = dir.path().join("schema1").join("table").join("table1.sql"); + let file_path2 = dir.path().join("schema1").join("table").join("table2.sql"); + let file_path3 = dir.path().join("schema1").join("table").join("table3.sql"); + let file_path4 = dir.path().join("schema1").join("table").join("table4.sql"); fs::create_dir_all(file_path1.parent().unwrap()).unwrap(); fs::create_dir_all(file_path2.parent().unwrap()).unwrap(); fs::create_dir_all(file_path3.parent().unwrap()).unwrap(); @@ -537,19 +551,19 @@ mod tests { assert!(result.is_ok()); let object_info = result.unwrap(); assert_eq!(object_info.len(), 4); - assert!(object_info.contains_key("schema1.tables.table1.change1")); - assert!(object_info.contains_key("schema1.tables.table2.change2")); - assert!(object_info.contains_key("schema1.tables.table3.change3")); - assert!(object_info.contains_key("schema1.tables.table4.change4")); + assert!(object_info.contains_key("schema1.table.table1.change1")); + assert!(object_info.contains_key("schema1.table.table2.change2")); + assert!(object_info.contains_key("schema1.table.table3.change3")); + assert!(object_info.contains_key("schema1.table.table4.change4")); // Assert dependencies - let change2 = object_info.get("schema1.tables.table2.change2").unwrap(); + let change2 = object_info.get("schema1.table.table2.change2").unwrap(); assert!(change2.dependencies.contains("table1")); - let change3 = object_info.get("schema1.tables.table3.change3").unwrap(); + let change3 = object_info.get("schema1.table.table3.change3").unwrap(); assert!(change3.dependencies.contains("change1")); - let change4 = object_info.get("schema1.tables.table4.change4").unwrap(); + let change4 = object_info.get("schema1.table.table4.change4").unwrap(); assert!(change4.dependencies.contains("table2")); assert!(change4.dependencies.contains("change3")); } @@ -557,7 +571,7 @@ mod tests { #[test] fn test_file_name_matches_object_name() { let dir = tempfile::tempdir().unwrap(); - let file_path1 = dir.path().join("schema1/tables/change1.sql"); + let file_path1 = dir.path().join("schema1/table/change1.sql"); fs::create_dir_all(file_path1.parent().unwrap()).unwrap(); let mut file1 = File::create(&file_path1).unwrap(); writeln!( @@ -573,11 +587,30 @@ mod tests { .contains("Object name 'change1' in file does not match name 'table1' in SQL")); } + #[test] + fn test_schema_name_matches_object_schema() { + let dir = tempfile::tempdir().unwrap(); + let file_path1 = dir.path().join("schema1/table/table1.sql"); + fs::create_dir_all(file_path1.parent().unwrap()).unwrap(); + let mut file1 = File::create(&file_path1).unwrap(); + writeln!( + file1, + "//// CHANGE name=table1\nCREATE TABLE schema2.table1 (id INT);\nGO" + ) + .unwrap(); + + let result = read_source_code(dir.path().to_str().unwrap()); + assert!(result.is_err()); + let error_message = result.unwrap_err().to_string(); + assert!(error_message + .contains("Schema name 'schema1' in file does not match schema name 'schema2' in SQL")); + } + #[test] fn test_circular_dependency() { let dir = tempfile::tempdir().unwrap(); - let file_path1 = dir.path().join("schema1/tables/table1.sql"); - let file_path2 = dir.path().join("schema1/tables/table2.sql"); + let file_path1 = dir.path().join("schema1/table/table1.sql"); + let file_path2 = dir.path().join("schema1/table/table2.sql"); fs::create_dir_all(file_path1.parent().unwrap()).unwrap(); fs::create_dir_all(file_path2.parent().unwrap()).unwrap(); @@ -603,9 +636,9 @@ mod tests { #[test] fn test_circular_dependency_three_objects() { let dir = tempfile::tempdir().unwrap(); - let file_path1 = dir.path().join("schema1/tables/table1.sql"); - let file_path2 = dir.path().join("schema1/tables/table2.sql"); - let file_path3 = dir.path().join("schema1/tables/table3.sql"); + let file_path1 = dir.path().join("schema1/table/table1.sql"); + let file_path2 = dir.path().join("schema1/table/table2.sql"); + let file_path3 = dir.path().join("schema1/table/table3.sql"); fs::create_dir_all(file_path1.parent().unwrap()).unwrap(); fs::create_dir_all(file_path2.parent().unwrap()).unwrap(); @@ -678,4 +711,11 @@ mod tests { assert!(parsed_stmts.contains_key("root0")); assert!(parsed_stmts.contains_key("root1")); } + + #[test] + fn test_read_source_code_with_one_schema() { + let source_code = read_source_code("tests/schemas/baseline/").unwrap(); + assert_eq!(source_code.len(), 1); + assert!(source_code.contains_key("schema1.table.table1.root0")); + } }